/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph_optimizer/ub_fusion/buffer_fusion_pass_runner.h"
#include "common/fusion_statistic/fusion_statistic_writer.h"
#include "common/configuration.h"
#include "common/unknown_shape_util.h"
#include "register/graph_optimizer/fusion_common/graph_pass_util.h"
#include "common/util/platform_info.h"
#include "graph_optimizer/ub_fusion/tbe_pass/ub_pass_slice_info/ub_pass_slice_info_manager.h"

namespace fe {
namespace {
// File-local attribute key string "_stream_label" (unnamed namespace already gives internal linkage).
const char *STREAM_LABEL = "_stream_label";
}  // namespace

/*
 * @brief: build a runner around one buffer fusion pass
 * @param [in] name: name given to both the runner and the created pass
 * @param [in] create_fn: factory returning a new pass object; the runner stores it in
 *                        buffer_fusion_pass_base_ptr_ (create_fn and its result are used unchecked)
 * @param [in] scope_allocat: scope allocator stored for later scope attribute queries
 * @param [in] reachability: reachability map stored as-is; Run() builds one when it is null
 */
BufferFusionPassRunner::BufferFusionPassRunner(const string &name, BufferFusionPassBase *(*create_fn)(),
                                               const ScopeAllocatorPtr &scope_allocat,
                                               std::shared_ptr<ReachabilityMap> &reachability)
    : scope_allocator_ptr_(scope_allocat), reachability_(reachability) {
  SetName(name);

  // Hand the freshly created pass object over to the owning smart pointer, then name it.
  buffer_fusion_pass_base_ptr_.reset(create_fn());
  buffer_fusion_pass_base_ptr_->SetName(name);

  // Fixed list of op type names kept in cube_op_type_ (presumably the cube-unit ops -- confirm with its users).
  cube_op_type_ = {"BasicLSTMCellV2",
                   "MatMul",
                   "MatMulV2",
                   "BatchMatMul",
                   "GEMM",
                   "ROIAlign",
                   "Pooling",
                   "FullyConnection",
                   "Conv2DBackpropFilterD",
                   "Conv2DBackpropFilter",
                   "Conv2DBackpropInputD",
                   "Conv2DBackpropInput",
                   "Deconvolution",
                   "Conv2DTransposeD",
                   "Conv2D",
                   "DepthwiseConv2D",
                   "DepthwiseConv2DBackpropFilterD",
                   "DepthwiseConv2DBackpropInputD",
                   "ROIPooling",
                   "BasicLSTMCellWeightGrad",
                   "LRN",
                   "PSROIPooling",
                   "Conv3D",
                   "Conv3DBackpropInputD",
                   "Conv3DTransposeD",
                   "Conv3DBackpropFilterD"};
}

/*
 * @brief: release every pattern object held in patterns_
 */
BufferFusionPassRunner::~BufferFusionPassRunner() {
  for (auto *owned_pattern : patterns_) {
    // deleting a null pointer is a no-op, so no explicit null test is required
    delete owned_pattern;
  }
}

/*
 * @brief: fetch the patterns defined by the wrapped pass and try each of them on the graph
 * @param [in] graph: graph to match against
 * @return Status: always SUCCESS; a pattern with config errors or a failed pattern run is
 *                 only logged and the next pattern is tried
 */
Status BufferFusionPassRunner::Run(ge::ComputeGraph &graph) {
  // 1. ask the pass for its pattern definitions
  patterns_ = buffer_fusion_pass_base_ptr_->DefinePatterns();

  // build the reachability map lazily when the creator did not supply one
  if (reachability_ == nullptr) {
    reachability_ = ReachabilityMap::Build(graph);
  }

  // 2. walk through every pattern
  for (BufferFusionPattern *candidate : patterns_) {
    if (candidate == nullptr) {
      continue;
    }
    string pattern_name = candidate->GetName();
    if (candidate->GetErrorCnt()) {
      REPORT_FE_ERROR("[SubGraphOpt][UB][Run] [%s] pattern has error config, error count is [%ld], and it's invalid.",
                      pattern_name.c_str(), candidate->GetErrorCnt());
      continue;
    }
    // 3. match the pattern ops against the graph ops (op type and TBE type)
    if (!RunOnePattern(graph, *candidate)) {
      FE_LOGW("Run Pass[%s]Pattern[%s] failed.", GetName().c_str(), candidate->GetName().c_str());
    }
  }
  return SUCCESS;
}

/*
 * @brief: tell whether the node's imply-type attribute equals domi::ImplyType::TVM
 * @param [in] node: graph node
 * @return bool: true for a TVM imply type; false otherwise, including a null node
 *               (a missing attribute leaves the type at 0)
 */
bool BufferFusionPassRunner::IsTbeOp(ge::NodePtr node) {
  FE_CHECK((node == nullptr),
           REPORT_FE_ERROR("[SubGraphOpt][UbFusion][IsTbeOP] null node in judging AICoreOp"), return false);

  int64_t imply_type = 0;
  (void)ge::AttrUtils::GetInt(node->GetOpDesc(), ge::ATTR_NAME_IMPLY_TYPE, imply_type);
  if (imply_type == static_cast<int64_t>(domi::ImplyType::TVM)) {
    return true;
  }

  FE_LOGD("op [%s] is not tbe op", node->GetName().c_str());
  return false;
}

/*
 * @brief: decide whether a node must be left out of UB fusion
 * @param [in] node: graph node
 * @return bool: true when the node is not a TBE op, has no supported op pattern, or already
 *               carries a scope attribute; false otherwise (a null node also yields false)
 */
bool BufferFusionPassRunner::NeedIgnoreOp(ge::NodePtr node) {
  FE_CHECK((node == nullptr),
           REPORT_FE_ERROR("[SubGraphOpt][UbFusion][NeedIgnOp] null node in judging ValidOp"), return false);

  // only TBE ops take part in UB fusion
  const bool is_tbe = IsTbeOp(node);
  if (!is_tbe) {
    FE_LOGD("node [%s] is not tbe op, and will be skipped ub fusion.", node->GetName().c_str());
    return true;
  }

  // the node's op pattern has to be one of the supported ones
  const bool pattern_supported = NodeType(node);
  if (!pattern_supported) {
    FE_LOGD("Fusion pattern of node [%s] is not supported, which cannot be fused with any other ops.",
        node->GetName().c_str());
    return true;
  }

  // a node that already has a scope id has been fused before
  int64_t scope_id = 0;
  const bool already_fused = (scope_allocator_ptr_->GetScopeAttr(node->GetOpDesc(), scope_id) == true);
  if (already_fused) {
    FE_LOGD("node %s has been fused.", node->GetName().c_str());
  }
  return already_fused;
}

/*
 * @brief: check whether the node's op pattern is listed in OP_PATTERN_VEC
 * @param [in] node: graph node
 * @return bool: true when the pattern attribute can be read and is found in OP_PATTERN_VEC
 */
bool BufferFusionPassRunner::NodeType(ge::NodePtr node) {
  FE_CHECK((node == nullptr), FE_LOGD("null node in judging NodeType"), return false);

  string pattern_type;
  if (!GetOpAttrType(node, pattern_type)) {
    return false;
  }

  const auto found = std::find(OP_PATTERN_VEC.begin(), OP_PATTERN_VEC.end(), pattern_type);
  if (found == OP_PATTERN_VEC.end()) {
    FE_LOGD("Node[%s]: can not find op pattern [%s] in OP_PATTERN_VEC.", node->GetName().c_str(),
            pattern_type.c_str());
    return false;
  }
  return true;
}

/*
 * @brief: read the op pattern of a node from its "<node name>_pattern" string attribute
 * @param [in] node: graph node
 * @param [out] op_type: pattern string read from the attribute
 * @return bool: true when the attribute exists and is not empty
 */
bool BufferFusionPassRunner::GetOpAttrType(ge::NodePtr node, string &op_type) {
  FE_CHECK((node == nullptr), REPORT_FE_ERROR("[SubGraphOpt][UbFusion][GetOpAttrType] node is nullptr."), return false);

  const string node_name = node->GetName();
  const string attr_key = node_name + "_pattern";

  const bool has_attr = ge::AttrUtils::GetStr(node->GetOpDesc(), attr_key, op_type);
  if (!has_attr) {
    FE_LOGD("node[%s] failed to get pattern [%s].", node_name.c_str(), attr_key.c_str());
    return false;
  }

  if (op_type.empty()) {
    REPORT_FE_ERROR("[SubGraphOpt][UbFusion][GetOpAttrType] optype is empty for node name [%s].", node_name.c_str());
    return false;
  }
  return true;
}

bool BufferFusionPassRunner::IsOpTypeAny(const std::vector<string> &types) {
  return find(types.begin(), types.end(), TBE_PATTERN_OP_TYPE_ANY) != types.end();
}

bool BufferFusionPassRunner::IsOutputNode(const std::vector<string> &types) {
  return find(types.begin(), types.end(), TBE_PATTERN_OUTPUT_NODE) != types.end();
}

/*
 * @brief: check whether a graph node matches the type list of a pattern desc
 * @param [in] node: graph node
 * @param [in] op_desc: candidate pattern desc
 * @return bool: true when
 *           - the node's op pattern is listed in op_desc->types, or
 *           - op_desc->types contains the "any" type or the "output node" type, or
 *           - the node has no readable op pattern and op_desc->types contains the "output node" type;
 *         false otherwise, and false when node or op_desc is null
 */
bool BufferFusionPassRunner::IsOpTypeExist(ge::NodePtr node, const BufferFusionOpDesc *op_desc) {
  // guard both inputs before the first dereference, in the same way the sibling checks do
  FE_CHECK(node == nullptr,
           REPORT_FE_ERROR("[SubGraphOpt][UbFusion][IsOpTypeExist] node is null."), return false);
  FE_CHECK(op_desc == nullptr,
           REPORT_FE_ERROR("[SubGraphOpt][UbFusion][IsOpTypeExist] opDesc is null."), return false);

  // bind by reference: this function is called for every candidate desc, a copy per call is wasted work
  const std::vector<string> &types = op_desc->types;
  const string name = node->GetName();

  string op_type;
  if (!GetOpAttrType(node, op_type)) {
    // a node without op pattern can only stand for an output node of the pattern
    if (IsOutputNode(types)) {
      FE_LOGD("Node:[%s] is output node.", name.c_str());
      return true;
    }
    FE_LOGD("Node:[%s] is not output node.", name.c_str());
    return false;
  }

  // exact op pattern match
  if (std::find(types.begin(), types.end(), op_type) != types.end()) {
    return true;
  }
  // wildcard desc: "OpTypeAny" accepts every op pattern
  if (IsOpTypeAny(types)) {
    FE_LOGD("Node:%s, Type:%s, Match Op Pattern ANY", name.c_str(), op_type.c_str());
    return true;
  }
  // output desc accepts every op pattern as well
  if (IsOutputNode(types)) {
    FE_LOGD("Node:%s, Type:%s, Match Op Pattern OUTNODE", name.c_str(), op_type.c_str());
    return true;
  }
  return false;
}

/*
 * @brief: check whether the number of output data nodes of a node conflicts with the
 *         output branch type of a candidate desc
 * @param [in] node: graph node
 * @param [in] op_desc: candidate pattern desc
 * @param [in] pattern_name: name of the pattern being matched
 * @return bool: true when the desc must be skipped for this node; false otherwise
 *               (also false when node or op_desc is null)
 */
bool BufferFusionPassRunner::SkipDiffSizeDesc(ge::NodePtr node, const BufferFusionOpDesc *op_desc,
                                              const string &pattern_name) {
  FE_CHECK(node == nullptr, REPORT_FE_ERROR("[SubGraphOpt][UbFusion][SkipDiffSizeDesc] node is null."), return false);
  FE_CHECK(op_desc == nullptr,
           REPORT_FE_ERROR("[SubGraphOpt][UbFusion][SkipDiffSizeDesc] opDesc is null."), return false);

  const auto consumer_count = node->GetOutDataNodes().size();

  // a node with one consumer cannot serve a multi-branch desc
  if (consumer_count == 1 && op_desc->out_branch_type == TBE_OUTPUT_BRANCH_MULTI) {
    FE_LOGD("Node[%s]: the size of out_data_nodes is 1, but the out_brand_type is TBE_OUTPUT_BRANCH_MULTI, skip.",
            node->GetName().c_str());
    return true;
  }

  const bool fan_out_on_single_desc = consumer_count > 1 && op_desc->out_branch_type == TBE_OUTPUT_BRANCH_SINGLE;
  if (!fan_out_on_single_desc) {
    return false;
  }

  // exception for common_rules2: a "Convolution" pattern node is accepted although it fans out
  string op_pattern;
  const bool conv_in_common_rules2 = GetOpAttrType(node, op_pattern) &&
                                     pattern_name == "TbeCommonRules2FusionPass" && op_pattern == "Convolution";
  if (conv_in_common_rules2) {
    return false;
  }

  FE_LOGD("Node[%s]: the size of out_data_nodes is > 1, but the out_brand_type is TBE_OUTPUT_BRANCH_SINGLE, skip.",
          node->GetName().c_str());
  return true;
}

/*
 * @brief: check whether the shape kind of a node (as reported by IsFeSupportedDynamicOp)
 *         conflicts with the shape type rule of a candidate desc
 * @param [in] node: graph node
 * @param [in] op_desc: candidate pattern desc
 * @return bool: true when the desc must be skipped (also when node or op_desc is null)
 */
bool BufferFusionPassRunner::SkipDiffShapeTypeDesc(ge::NodePtr node, const BufferFusionOpDesc *op_desc) {
  if (node == nullptr || op_desc == nullptr) {
    return true;
  }

  const bool is_dynamic = IsFeSupportedDynamicOp(*(node->GetOpDesc()), true);

  if (is_dynamic && op_desc->shape_type_rule == ONLY_SUPPORT_STATIC) {
    FE_LOGD("Node[%s, %s] whose shape is dynamic shall be skipped for the buffer desc only supports static shape.",
            node->GetName().c_str(), node->GetType().c_str());
    return true;
  }

  if (!is_dynamic && op_desc->shape_type_rule == ONLY_SUPPORT_DYNAMIC) {
    FE_LOGD("Node[%s, %s] whose shape is static shall be skipped for the buffer desc only supports dynamic shape.",
            node->GetName().c_str(), node->GetType().c_str());
    return true;
  }

  return false;
}

/*
 * @brief: get current loop fusiton match status
 * @param [in] is_parallel: graph node is multi branch or single branch
 * @param [in] op_descs: candidated pattern desc
 * @param [in] usage: record whether desc has beed matched
 * @return bool: all current loop descs have beed matched or not
 */
bool BufferFusionPassRunner::GetCurrMatchStatus(bool is_parallel, std::vector<BufferFusionOpDesc *> op_descs,
                                                std::map<BufferFusionOpDesc *, bool> usage) {
  bool match_status = false;

  // check match status
  if (is_parallel) {
    match_status = true;
    for (auto op_desc : op_descs) {
      if (usage.find(op_desc) != usage.end()) {
        if (usage[op_desc] == false) {
          match_status = false;
          break;
        }
      }
    }
  } else {
    match_status = false;
    for (auto op_desc : op_descs) {
      if (usage.find(op_desc) != usage.end()) {
        if (usage[op_desc] == true) {
          match_status = true;
          break;
        }
      }
    }
  }

  return match_status;
}

/*
 * @brief: get the match status of a whole fusion pattern
 * @param [in] pattern: fusion pattern info
 * @return bool: true when every non-input desc is satisfied: a desc with a valid group id is
 *               satisfied when any desc of its group reached repeate_min, a desc without a
 *               group has to reach its own repeate_min
 */
bool BufferFusionPassRunner::GetPatternMatchStatus(BufferFusionPattern &pattern) {
  // pass 1: a group is satisfied as soon as one of its descs reached the minimum repeat count
  std::map<int64_t, bool> group_matched;
  for (const auto &desc : pattern.GetOpDescs()) {
    if (desc->types[0] == TBE_PATTERN_INPUT_NODE || desc->group_id == TBE_PATTERN_GROUPID_INVALID) {
      continue;
    }
    bool &matched = group_matched[desc->group_id];  // a new entry starts as false
    if (desc->repeate_curr >= desc->repeate_min) {
      matched = true;
    }
  }

  // pass 2: the first unsatisfied desc fails the pattern
  for (const auto &desc : pattern.GetOpDescs()) {
    if (desc->types[0] == TBE_PATTERN_INPUT_NODE) {
      continue;
    }
    if (desc->group_id != TBE_PATTERN_GROUPID_INVALID) {
      if (!group_matched[desc->group_id]) {
        FE_LOGD("group[%ld] not match", desc->group_id);
        return false;
      }
      continue;
    }
    if (desc->repeate_curr < desc->repeate_min) {
      FE_LOGD("pattern not match info: desc name=[%s], curr_match cnt=[%ld], min_match cnt=[%ld]",
          desc->desc_name.c_str(), desc->repeate_curr, desc->repeate_min);
      return false;
    }
  }
  return true;
}

/*
 * @brief: find the head desc of a fusion pattern that matches a graph node
 * @param [in] node: graph node
 * @param [in] pattern_name: name of the pattern being matched
 * @param [in] head_descs: candidate head desc list
 * @return BufferFusionOpDesc*: first matching head desc, nullptr when none matches
 */
BufferFusionOpDesc *BufferFusionPassRunner::GetMatchedHeadDesc(ge::NodePtr node, const string &pattern_name,
                                                               std::vector<BufferFusionOpDesc *> head_descs) {
  for (BufferFusionOpDesc *candidate : head_descs) {
    for (int round = 0; round < TBE_MATCH_LOOP_NUM; ++round) {
      // the size / shape-type filters are applied in the first round only and drop the candidate
      const bool first_round = (round == 0);
      if (first_round &&
          (SkipDiffSizeDesc(node, candidate, pattern_name) || SkipDiffShapeTypeDesc(node, candidate))) {
        break;
      }
      if (!IsOpTypeExist(node, candidate)) {
        continue;
      }
      FE_LOGD("Node [%s], desc[%s] from graph has matched to head desc from fusion pattern.",
              node->GetName().c_str(), candidate->desc_name.c_str());
      return candidate;
    }
  }
  return nullptr;
}

/*
 * @brief: find the desc of the current loop that matches a graph node
 * @param [in] node: graph node (dereferenced without a null check)
 * @param [in] head_desc: head desc of the pattern, forwarded to SkipNodeForNormalDesc
 * @param [in] descs: candidate descs of the current loop
 * @param [in] usage: records whether a desc has been matched; taken by value, so the
 *                    operator[] lookups below can only modify this local copy
 * @param [in] matched_output_nodes: node names already matched per desc name and repeat index,
 *                                   forwarded to SkipNodeForNormalDesc
 * @param [in] pattern_name: name of the pattern being matched
 * @return BufferFusionOpDesc*: matched desc ptr, nullptr when no desc matches
 */
BufferFusionOpDesc *BufferFusionPassRunner::GetMatchedNormalDesc(
    ge::NodePtr node, BufferFusionOpDesc *head_desc, std::vector<BufferFusionOpDesc *> descs,
    std::map<BufferFusionOpDesc *, bool> usage,
    std::map<std::string, std::map<int32_t, std::vector<std::string>>> &matched_output_nodes,
    const string &pattern_name) {
  // fallback candidate: the last "any"/"output node" desc that accepted the node
  BufferFusionOpDesc *out_dst_desc = nullptr;
  std::string node_name = node->GetName();

  // the round index is handed to SkipNodeForNormalDesc as loop_num, which filters
  // differently in the first round (loop_num == 0) than in later rounds
  for (int test = 0; test < TBE_MATCH_LOOP_NUM; test++) {
    for (auto out_desc : descs) {
      if (SkipNodeForNormalDesc(matched_output_nodes, out_desc, node_name, node, head_desc, test, pattern_name)) {
        continue;
      }
      // only descs not yet used in this loop are considered
      if (!usage[out_desc] && IsOpTypeExist(node, out_desc)) {
        // a wildcard desc is only remembered; keep looking for a concrete desc first
        if (IsOpTypeAny(out_desc->types) || IsOutputNode(out_desc->types)) {
          out_dst_desc = out_desc;
          continue;
        }
        FE_LOGD("match node name:%s, desc:%s", node_name.c_str(), out_desc->desc_name.c_str());
        return out_desc;
      }
    }

    // no concrete desc matched in this round: fall back to the remembered wildcard desc
    if (out_dst_desc != nullptr) {
      FE_LOGD("match node name:%s, desc:%s", node->GetName().c_str(), out_dst_desc->desc_name.c_str());
      return out_dst_desc;
    }
  }

  // reached only when no round produced a match (out_dst_desc is still nullptr here)
  return out_dst_desc;
}

/*
 * @brief: match the output data nodes of a node against the descs of the current loop
 * @param [in] node: graph node whose output data nodes are matched
 * @param [in/out] queue_descs: desc work queue; matched concrete descs are appended.
 *                              front() is read below, so it must not be empty
 * @param [in/out] queue_nodes: node work queue; matched concrete nodes are appended
 * @param [in] curr_descs: candidate descs of the current loop
 * @param [in] pattern: fusion pattern info (only its name is used here)
 * @param [out] usage_flags: set to false for every candidate first, then to true for each matched desc
 * @param [in/out] mapping: desc -> matched nodes; every matched node is appended
 * @param [in] head_desc: head desc of the pattern
 * @param [in/out] matched_output_nodes: node names already matched per desc name and repeat index
 */
void BufferFusionPassRunner::MatchFollowingNodes(
    ge::NodePtr node, std::vector<BufferFusionOpDesc *> &queue_descs, std::vector<ge::NodePtr> &queue_nodes,
    std::vector<BufferFusionOpDesc *> &curr_descs, BufferFusionPattern &pattern,
    std::map<BufferFusionOpDesc *, bool> &usage_flags, BufferFusionMapping &mapping, BufferFusionOpDesc *head_desc,
    std::map<std::string, std::map<int32_t, std::vector<std::string>>> &matched_output_nodes) {
  auto curr_nodes = node->GetOutDataNodes();
  // start with every candidate desc marked as unused
  for (auto desc : curr_descs) {
    usage_flags[desc] = false;
  }

  for (auto &out_node : curr_nodes) {
    std::string out_node_name = out_node->GetName();
    BufferFusionOpDesc *out_desc =
        GetMatchedNormalDesc(out_node, head_desc, curr_descs, usage_flags, matched_output_nodes, pattern.GetName());
    if (out_desc != nullptr) {
      // a node that must be ignored can only be taken by a wildcard ("any"/"output node") desc
      if (NeedIgnoreOp(out_node) && !IsOpTypeAny(out_desc->types) && !IsOutputNode(out_desc->types)) {
        FE_LOGD("outDesc node [%s] is ignored, out_desc:%s", out_node_name.c_str(), out_desc->desc_name.c_str());
        continue;
      }
      // only concrete descs are queued for further matching and recorded in matched_output_nodes
      if (!IsOpTypeAny(out_desc->types) && !IsOutputNode(out_desc->types)) {
        // CheckLoopForward returns true when adding the node would create a loop
        if (CheckLoopForward(mapping, out_node)) {
          continue;
        }
        queue_nodes.push_back(out_node);
        queue_descs.push_back(out_desc);
        // remember the node name under (desc name, current repeat index)
        auto it = matched_output_nodes.find(out_desc->desc_name);
        if (it != matched_output_nodes.end()) {
          (it->second)[out_desc->repeate_curr].push_back(out_node_name);
        } else {
          std::map<int32_t, std::vector<std::string>> temp;
          temp.insert(std::pair<int32_t, std::vector<std::string>>(out_desc->repeate_curr, {out_node_name}));
          matched_output_nodes.insert(
              std::pair<std::string, std::map<int32_t, std::vector<std::string>>>(out_desc->desc_name, temp));
        }
      }
      // add fused node to mapping (wildcard matches are added as well)
      mapping[out_desc].push_back(out_node);
      // the repeat counter of the desc is increased once it has been matched
      if (CheckInt64AddOverflow(out_desc->repeate_curr, 1) != SUCCESS) {
        REPORT_FE_ERROR("[SubGraphOpt][UbFusion][MtcFollowNd] repeateCurr++ overflow. (out_desc:%s)",
                        out_desc->desc_name.c_str());
        return;
      }
      out_desc->repeate_curr++;
      usage_flags[out_desc] = true;
      // unless the desc at the front of the queue is multi-branch, one matched output is enough
      if (queue_descs.front()->out_branch_type != TBE_OUTPUT_BRANCH_MULTI) {
        break;
      }
    } else {
      FE_LOGD("Output node [%s] has not been matched to any desc from fusion pattern.", out_node->GetName().c_str());
    }
  }
}

/*
 * @brief: group the direct nodes of a graph by their scope id
 * @param [in] graph: graph whose direct nodes are inspected
 * @param [out] fusion_scopes: scope id -> nodes carrying that scope id; nodes without a
 *                             readable scope attribute are left out
 */
void BufferFusionPassRunner::GetExistingFusionScopes(ge::ComputeGraph &graph,
                                                     std::map<int64_t, vector<ge::NodePtr>> &fusion_scopes) {
  for (const ge::NodePtr &candidate : graph.GetDirectNode()) {
    if (!scope_allocator_ptr_->HasScopeAttr(candidate->GetOpDesc())) {
      continue;
    }
    int64_t scope_id = 0;
    if (!scope_allocator_ptr_->GetScopeAttr(candidate->GetOpDesc(), scope_id)) {
      continue;
    }
    fusion_scopes[scope_id].push_back(candidate);
  }
}

/*
 * @brief: decide recursively whether the outputs of a desc are optional
 * @param [in] desc: pattern desc (dereferenced without a null check)
 * @return bool:
 *   - true when the desc has fewer outputs than its out_branch_type value;
 *   - single-branch desc with more than one output: true when at least one output is optional;
 *   - every other desc: true when all outputs are optional (also true with no outputs).
 *   An output counts as optional when it is not a mandatory concrete desc (not "any", not
 *   "output node", repeate_min > 0) and IsOptionalOutput() holds for it as well.
 */
bool BufferFusionPassRunner::IsOptionalOutput(BufferFusionOpDesc *desc) {
  if (desc->out_branch_type > static_cast<int>(desc->outputs.size())) {
    FE_LOGW("%s outputs size is less than out_branch_type required, consider it as optional output.",
        desc->desc_name.c_str());
    return true;
  }

  const bool single_branch_with_fan_out =
      desc->out_branch_type == TBE_OUTPUT_BRANCH_SINGLE && desc->outputs.size() > 1;

  if (single_branch_with_fan_out) {
    // one optional output is sufficient
    for (auto child : desc->outputs) {
      const bool mandatory = !IsOpTypeAny(child->types) && !IsOutputNode(child->types) && child->repeate_min > 0;
      if (!mandatory && IsOptionalOutput(child)) {
        return true;
      }
    }
    return false;
  }

  // every output has to be optional
  for (auto child : desc->outputs) {
    const bool mandatory = !IsOpTypeAny(child->types) && !IsOutputNode(child->types) && child->repeate_min > 0;
    if (mandatory || !IsOptionalOutput(child)) {
      return false;
    }
  }
  return true;
}

/*
 * @brief: check whether adding a node to the current mapping would create a loop
 * @param [in] mapping: desc -> nodes matched so far
 * @param [in] targetnode: node that is about to be added
 * @return bool: true when targetnode is reachable (per reachability_) from a successor of a
 *               mapped node, where that successor is neither targetnode itself nor one of the
 *               nodes mapped to a concrete (non "any" / non "output node") desc
 */
bool BufferFusionPassRunner::CheckLoopForward(BufferFusionMapping &mapping, ge::NodePtr &targetnode) {
  // nodes mapped to concrete descs; successors inside this set are not checked
  std::vector<ge::NodePtr> fused_nodes;
  for (const auto &entry : mapping) {
    if (IsOpTypeAny(entry.first->types) || IsOutputNode(entry.first->types)) {
      continue;
    }
    fused_nodes.insert(fused_nodes.end(), entry.second.begin(), entry.second.end());
  }

  for (const auto &entry : mapping) {
    for (const auto &mapped_node : entry.second) {
      for (auto successor : mapped_node->GetOutAllNodes()) {
        const bool is_target = (successor == targetnode);
        if (is_target || find(fused_nodes.begin(), fused_nodes.end(), successor) != fused_nodes.end()) {
          continue;
        }
        if (reachability_->IsReachable(successor, targetnode)) {
          FE_LOGD("target node %s is a sub node of %s, a loop will be generated. skip it.",
                  targetnode->GetName().c_str(), successor->GetName().c_str());
          return true;
        }
      }
    }
  }
  return false;
}

/*
 * @brief: keep the mapping that holds the most nodes
 * @param [in] curr_mapping: mapping of the current match attempt
 * @param [in/out] longest_mapping: replaced by curr_mapping when curr_mapping holds more nodes
 * @param [in/out] longest_num: node count of longest_mapping, updated together with it
 */
void BufferFusionPassRunner::CompareMappings(BufferFusionMapping &curr_mapping, BufferFusionMapping &longest_mapping,
                                             size_t &longest_num) {
  // only the total node count matters, so sum the sizes instead of copying every node pointer
  size_t node_count = 0;
  for (const auto &item : curr_mapping) {
    node_count += item.second.size();
  }

  if (node_count > longest_num) {
    longest_mapping = curr_mapping;
    longest_num = node_count;
    FE_LOGD("set current mapping as the longest mapping. fused nodes number is %zu.",
        longest_num);
  }
}

/*
 * @brief: finish the current match attempt when its queues are exhausted (or the match failed)
 *         and restore the most recently saved attempt, if any
 * @param [in/out] saved_queue_descs: stack of saved desc queues
 * @param [in/out] saved_queue_nodes: stack of saved node queues
 * @param [in/out] saved_mappings: stack of saved mappings
 * @param [in/out] curr_queue_descs: current desc queue
 * @param [in/out] curr_queue_nodes: current node queue
 * @param [in/out] curr_mapping: current mapping
 * @param [in] match_error: true when the current attempt failed; both current queues are cleared
 * @param [in/out] longest_mapping: longest valid mapping found so far
 * @param [in/out] longest_num: node count of longest_mapping
 * @param [in/out] pattern: fusion pattern info; repeate_curr of its descs is rewritten
 */
void BufferFusionPassRunner::RecoverMappingAndQueue(
    vector<vector<BufferFusionOpDesc *>> &saved_queue_descs, vector<vector<ge::NodePtr>> &saved_queue_nodes,
    BufferFusionMappings &saved_mappings, vector<BufferFusionOpDesc *> &curr_queue_descs,
    vector<ge::NodePtr> &curr_queue_nodes, BufferFusionMapping &curr_mapping, bool match_error,
    BufferFusionMapping &longest_mapping, size_t &longest_num, BufferFusionPattern &pattern) {
  // a failed match ends the current attempt immediately
  if (match_error) {
    curr_queue_descs.clear();
    curr_queue_nodes.clear();
  }
  // nothing happens while the current attempt still has queued work
  if (curr_queue_descs.empty() && curr_queue_nodes.empty()) {
    // the finished attempt competes for "longest mapping" only when it is a valid pattern match
    if (GetPatternMatchStatus(pattern) != false && CheckAttrMatch(curr_mapping)) {
      CompareMappings(curr_mapping, longest_mapping, longest_num);
    }
    // reset the repeat counter of every desc used by the finished attempt
    for (auto desc : pattern.GetOpDescs()) {
      if (curr_mapping.find(desc) != curr_mapping.end()) {
        desc->repeate_curr = 0;
      }
    }
    // resume the most recently saved attempt; the three stacks are popped together
    if (!saved_queue_descs.empty() && !saved_queue_nodes.empty() && !saved_mappings.empty()) {
      curr_queue_descs = saved_queue_descs.back();
      curr_queue_nodes = saved_queue_nodes.back();
      curr_mapping = saved_mappings.back();
      saved_queue_descs.pop_back();
      saved_queue_nodes.pop_back();
      saved_mappings.pop_back();
    } else {
      // no saved attempt left: the longest mapping becomes the result; the queues stay empty
      curr_mapping = longest_mapping;
    }
    // bring the repeat counters in line with the mapping that is current now
    for (auto desc : pattern.GetOpDescs()) {
      if (curr_mapping.find(desc) != curr_mapping.end()) {
        desc->repeate_curr = curr_mapping.find(desc)->second.size();
      }
    }
  }
}

/*
 * @brief: decide whether a candidate desc must be skipped for a graph node
 * @param [in/out] matched_output_nodes: node names already matched per desc name and repeat
 *                 index; looking up the current repeat index may create an empty entry
 * @param [in] out_desc: candidate desc (dereferenced without a null check)
 * @param [in] node_name: name of the node
 * @param [in] node: graph node
 * @param [in] head_desc: head desc of the pattern; the input count check is not applied to it
 * @param [in] loop_num: match round, 0 for the first round
 * @param [in] pattern_name: name of the pattern being matched
 * @return bool: true when the desc must not be matched with the node
 */
bool BufferFusionPassRunner::SkipNodeForNormalDesc(
    std::map<std::string, std::map<int32_t, std::vector<std::string>>> &matched_output_nodes,
    BufferFusionOpDesc *out_desc, std::string node_name, ge::NodePtr node, BufferFusionOpDesc *head_desc,
    int64_t loop_num, const string &pattern_name) {
  // 1. the node has already been matched to this desc at the current repeat index
  const auto recorded = matched_output_nodes.find(out_desc->desc_name);
  if (recorded != matched_output_nodes.end()) {
    const auto &names_at_repeat = (recorded->second)[out_desc->repeate_curr];
    if (find(names_at_repeat.begin(), names_at_repeat.end(), node_name) != names_at_repeat.end()) {
      FE_LOGD("skip matched node %s for opdesc %s.", node_name.c_str(), out_desc->desc_name.c_str());
      return true;
    }
  }

  // 2. output count mismatch: always skipped in the first round; in later rounds only
  //    concrete descs without optional outputs are skipped
  if (!out_desc->ignore_output_num && SkipDiffSizeDesc(node, out_desc, pattern_name)) {
    if (loop_num == 0) {
      return true;
    }
    const bool is_wildcard = IsOpTypeAny(out_desc->types) || IsOutputNode(out_desc->types);
    if (!is_wildcard && !IsOptionalOutput(out_desc)) {
      return true;
    }
  }

  // 3. input count mismatch (not checked for the head desc, output descs or when ignore_input_num is set)
  const bool input_count_mismatch = out_desc != head_desc && !out_desc->ignore_input_num &&
                                    node->GetInDataNodes().size() != out_desc->inputs.size() &&
                                    !IsOutputNode(out_desc->types);
  if (input_count_mismatch) {
    FE_LOGD("node size not same with desc, node name=[%s], input cnt=[%zu], desc inputsize=[%zu]",
        node_name.c_str(), node->GetInDataNodes().size(), out_desc->inputs.size());
    return true;
  }

  // 4. static / dynamic shape rule of the desc
  return SkipDiffShapeTypeDesc(node, out_desc);
}

/*
 * @brief: decide whether matching the outputs of a node can be skipped right away
 * @param [in] node: graph node whose outputs are about to be matched
 * @param [in] curr_node_num: number of output data nodes of node
 * @param [in] curr_desc_num: number of output descs of op_desc
 * @param [in] op_desc: desc the node has been matched to
 * @param [in] head_desc: head desc of the pattern
 * @param [in] get_output_result: whether the output descs of op_desc could be fetched
 * @param [in] pattern_name: name of the pattern being matched
 * @return bool: true when the node must be skipped
 */
bool BufferFusionPassRunner::SkipNodeBeforeMatch(const ge::NodePtr &node, size_t curr_node_num, size_t curr_desc_num,
                                                 BufferFusionOpDesc *op_desc, BufferFusionOpDesc *head_desc,
                                                 bool get_output_result, const string &pattern_name) {
  // One of the conditions for matching the longest structure is that the number of output nodes
  // does not exceed this limit.
  const size_t max_output_node_num = 10;

  if (!curr_node_num) {
    FE_LOGD("current node %s has no output node. skip it.", node->GetName().c_str());
    return true;
  }

  if (curr_node_num > max_output_node_num) {
    // size_t arguments need %zu; the reported limit now follows the real threshold
    FE_LOGD("output nodes[%zu] of current node %s is greater than %zu. skip it.", curr_node_num,
            node->GetName().c_str(), max_output_node_num);
    return true;
  }

  if (!get_output_result) {
    FE_LOGD("fail to get output desc for %s. skip it.", op_desc->desc_name.c_str());
    return true;
  }

  // several output nodes are accepted only for a multi-branch desc with exactly as many output
  // descs; descs with ignore_output_num are exempt unless they are the head desc
  if ((op_desc == head_desc || !op_desc->ignore_output_num) && curr_node_num > 1 &&
      (curr_node_num != curr_desc_num || op_desc->out_branch_type != TBE_OUTPUT_BRANCH_MULTI)) {
    // exception for common_rules2: a "Convolution" pattern node is accepted as head node
    string op_type = "";
    if (GetOpAttrType(node, op_type)) {
      if (pattern_name == "TbeCommonRules2FusionPass" &&  op_type == "Convolution") {
        FE_LOGD("dealwith [%s], conv is head node", pattern_name.c_str());
        return false;
      }
    }
    FE_LOGI("Not match info: out relation [%ld], outnode size [%zu], outdesc size [%zu]",
        op_desc->out_branch_type, curr_node_num, curr_desc_num);
    return true;
  }
  return false;
}

/*
 * @brief: snapshot the current queues and mapping before the outputs of a node are matched,
 *         so that another match attempt can be resumed from this point later
 * @param [in/out] curr_descs: candidate output descs; its first element may be removed
 * @param [in] node: graph node whose outputs are about to be matched
 * @param [in] op_desc: desc the node has been matched to
 * @param [in] queue_descs: current desc queue (copied into the snapshot)
 * @param [in] queue_nodes: current node queue (copied into the snapshot)
 * @param [in] mapping: current mapping (copied into the snapshot)
 * @param [in/out] saved_queue_descs: stack of saved desc queues
 * @param [in/out] saved_queue_nodes: stack of saved node queues
 * @param [in/out] saved_mappings: stack of saved mappings
 * @param [in/out] saved_count: incremented for the snapshots taken by the second and third
 *                              case below, not for the first one
 */
void BufferFusionPassRunner::SaveQueueBeforeMatch(std::vector<BufferFusionOpDesc *> &curr_descs, ge::NodePtr node,
                                                  BufferFusionOpDesc *op_desc,
                                                  std::vector<BufferFusionOpDesc *> &queue_descs,
                                                  std::vector<ge::NodePtr> &queue_nodes, BufferFusionMapping &mapping,
                                                  vector<vector<BufferFusionOpDesc *>> &saved_queue_descs,
                                                  vector<vector<ge::NodePtr>> &saved_queue_nodes,
                                                  BufferFusionMappings &saved_mappings, uint32_t &saved_count) {
  BufferFusionOpDesc *first_desc = nullptr;
  if (!curr_descs.empty()) {
    first_desc = curr_descs.front();
  }
  auto curr_nodes = node->GetOutDataNodes();
  // case 1: the first candidate desc may still repeat, its skip status at the current repeat index
  // is AVAILABLE, and the single output node fans out itself. The desc is marked SKIPPED and
  // removed from the candidates for this attempt, after which the state is saved.
  if (first_desc && !first_desc->multi_output_skip_status.empty() &&
      first_desc->repeate_max > first_desc->repeate_curr &&
      first_desc->multi_output_skip_status[first_desc->repeate_curr] == SkipStatus::AVAILABLE &&
      curr_nodes.size() == 1 && curr_nodes.at(0)->GetOutDataNodes().size() > 1) {
    first_desc->multi_output_skip_status[first_desc->repeate_curr] = SkipStatus::SKIPPED;
    FE_LOGD("try skipping node %s from repeated opdesc %s first.", curr_nodes.at(0)->GetName().c_str(),
            first_desc->desc_name.c_str());
    curr_descs.erase(curr_descs.begin(), curr_descs.begin() + 1);
    // NOTE(review): saved_count is not incremented for this snapshot -- confirm this is intended
    saved_queue_descs.push_back(queue_descs);
    saved_queue_nodes.push_back(queue_nodes);
    saved_mappings.push_back(mapping);
    FE_LOGD("save queue for multioutputskip.");
  }

  // case 2: the desc requires multiple output branches
  if (!op_desc->ignore_output_num && op_desc->out_branch_type == TBE_OUTPUT_BRANCH_MULTI) {
    saved_queue_descs.push_back(queue_descs);
    saved_queue_nodes.push_back(queue_nodes);
    saved_mappings.push_back(mapping);
    saved_count++;
    FE_LOGD("save queue for multiple output branch.");
  }

  // case 3: the desc ignores the output count and the node has several output nodes
  if (op_desc->ignore_output_num && curr_nodes.size() > 1) {
    saved_queue_descs.push_back(queue_descs);
    saved_queue_nodes.push_back(queue_nodes);
    saved_mappings.push_back(mapping);
    saved_count++;
    FE_LOGD("save queue for optional output.");
  }
}

/*
 * @brief: match a fusion pattern starting from the queued head node, trying the saved
 *         alternative attempts as well and keeping the mapping with the most nodes
 * @param [in/out] queue_descs: desc work queue, initially holding the head desc
 * @param [in/out] queue_nodes: node work queue, initially holding the head node
 * @param [in] pattern: fusion pattern info
 * @param [in/out] mapping: desc -> matched nodes; updated while matching and replaced by
 *                          RecoverMappingAndQueue when an attempt ends
 * @param [in] head_desc: head desc of the pattern
 */
void BufferFusionPassRunner::MatchFusionPattern(std::vector<BufferFusionOpDesc *> &queue_descs,
                                                std::vector<ge::NodePtr> &queue_nodes, BufferFusionPattern &pattern,
                                                BufferFusionMapping &mapping, BufferFusionOpDesc *head_desc) {
  // match all pattern descs from head desc
  // the incoming mapping is the initial "longest" one, with an initial node count of 1
  BufferFusionMapping longest_mapping = mapping;
  size_t longest_num = 1;
  // stacks of saved attempts (mapping + both queues), pushed and popped together
  BufferFusionMappings saved_mappings;
  std::vector<std::vector<BufferFusionOpDesc *>> saved_queue_descs;
  std::vector<std::vector<ge::NodePtr>> saved_queue_nodes;
  // node names already matched per desc name and repeat index
  std::map<std::string, std::map<int32_t, std::vector<std::string>>> matched_output_nodes;
  while (!queue_descs.empty() && !queue_nodes.empty()) {
    ge::NodePtr node = queue_nodes.front();
    BufferFusionOpDesc *op_desc = queue_descs.front();
    auto curr_nodes = node->GetOutDataNodes();
    std::vector<BufferFusionOpDesc *> curr_descs;
    bool res = pattern.GetOutputs(op_desc, curr_descs);
    // early rejection is handled like a failed match of the current attempt
    if (SkipNodeBeforeMatch(node, curr_nodes.size(), curr_descs.size(), op_desc, head_desc, res, pattern.GetName())) {
      RecoverMappingAndQueue(saved_queue_descs, saved_queue_nodes, saved_mappings, queue_descs, queue_nodes, mapping,
                             true, longest_mapping, longest_num, pattern);
      continue;
    }
    // the desc has no output descs and more work is queued: drop this entry and go on
    if (curr_descs.empty() && queue_descs.size() > 1 && queue_nodes.size() > 1) {
      queue_nodes.erase(queue_nodes.begin());
      queue_descs.erase(queue_descs.begin());
      continue;
    }
    // saved_count = number of snapshots that are removed again when this step fails
    uint32_t saved_count = 0;
    SaveQueueBeforeMatch(curr_descs, node, op_desc, queue_descs, queue_nodes, mapping, saved_queue_descs,
                         saved_queue_nodes, saved_mappings, saved_count);
    std::map<BufferFusionOpDesc *, bool> usage_flags;
    // match the output nodes of the node at the front of the queue
    MatchFollowingNodes(node, queue_descs, queue_nodes, curr_descs, pattern, usage_flags, mapping, head_desc,
                        matched_output_nodes);

    // check whether match is ok
    // (all descs must match when the output count matters and the node fans out, otherwise one is enough)
    bool match_status =
        GetCurrMatchStatus(!op_desc->ignore_output_num && curr_nodes.size() > 1, curr_descs, usage_flags);
    if (match_status == true) {
      // this node is done; RecoverMappingAndQueue only acts when the queues became empty
      queue_nodes.erase(queue_nodes.begin());
      queue_descs.erase(queue_descs.begin());
      RecoverMappingAndQueue(saved_queue_descs, saved_queue_nodes, saved_mappings, queue_descs, queue_nodes, mapping,
                             false, longest_mapping, longest_num, pattern);
    } else {
      // drop the snapshots counted for this step, then end the current attempt
      for (uint32_t i = 0; i < saved_count; i++) {
        saved_queue_descs.pop_back();
        saved_queue_nodes.pop_back();
        saved_mappings.pop_back();
        FE_LOGD("remove last queue for failed match.");
      }
      RecoverMappingAndQueue(saved_queue_descs, saved_queue_nodes, saved_mappings, queue_descs, queue_nodes, mapping,
                             true, longest_mapping, longest_num, pattern);
    }
  }
}

/*
 * @brief: try to match a fusion pattern with the given graph node as head node
 * @param [in] node_g: graph node used as head candidate
 * @param [in] pattern: fusion pattern info
 * @param [in/out] mapping: desc -> matched nodes, filled while matching
 * @return Status: SUCCESS when the pattern is matched and CheckAttrMatch accepts the mapping,
 *                 FAILED otherwise
 */
Status BufferFusionPassRunner::MatchFromHead(const ge::NodePtr &node_g, BufferFusionPattern &pattern,
                                             BufferFusionMapping &mapping) {
  // the node has to match one of the head descs of the pattern
  BufferFusionOpDesc *head_desc = GetMatchedHeadDesc(node_g, pattern.GetName(), pattern.GetHead());
  if (head_desc == nullptr) {
    FE_LOGD("Node [%s] from graph has not been matched to any head desc from fusion pattern.",
            node_g->GetName().c_str());
    return FAILED;
  }

  // record the head match and seed both work queues with it
  mapping[head_desc].push_back(node_g);
  head_desc->repeate_curr++;
  std::vector<BufferFusionOpDesc *> pending_descs;
  std::vector<ge::NodePtr> pending_nodes;
  pending_nodes.push_back(node_g);
  pending_descs.push_back(head_desc);

  // match the rest of the pattern starting from the head node
  MatchFusionPattern(pending_descs, pending_nodes, pattern, mapping, head_desc);

  // the whole pattern has to be satisfied and the attributes of the mapped nodes have to agree
  if (!GetPatternMatchStatus(pattern) || !CheckAttrMatch(mapping)) {
    return FAILED;
  }
  return SUCCESS;
}

/*
 * @brief: try one pattern on every direct node of the graph, and set the
 *         scope id / pass name on each group of nodes that matched
 * @param [in] graph: compute graph whose direct nodes are tried as pattern head
 * @param [in] pattern: fusion pattern info
 * @return bool: false when getting the fusion nodes fails or when
 *         FE_MAKE_SHARED takes its `return false` path; true otherwise,
 *         including when no node matched
 */
bool BufferFusionPassRunner::RunOnePattern(ge::ComputeGraph &graph, BufferFusionPattern &pattern) {
  int matched_times = 0;
  string pass_name = GetName();
  string pattern_name = pattern.GetName();
  BufferFusionMapping mapping;
  using UbPassSliceInfoManagerPtr = std::shared_ptr<UbPassSliceInfoManager>;
  // 1. compare 1st pattern op and graph op (include compare op type and TBE type)
  for (const ge::NodePtr &node_g : graph.GetDirectNode()) {
    // filter non TBE op
    if (NeedIgnoreOp(node_g)) {
      continue;
    }
    // every head candidate starts from an empty mapping
    mapping.clear();

    // initial all descs repeat curr cnt
    InitRepeatCurr(pattern.GetOpDescs());

    if (MatchFromHead(node_g, pattern, mapping) != SUCCESS) {
      continue;
    }

    vector<ge::NodePtr> fusion_nodes;
    Status status = buffer_fusion_pass_base_ptr_->GetFusionNodes(mapping, fusion_nodes);
    if (status != SUCCESS) {
      REPORT_FE_ERROR("[SubGraphOpt][UB][RunOnePtn] Pass[%s]Pattern[%s]: Failed to get fusion nodes because %u.",
                      pass_name.c_str(), pattern_name.c_str(), status);
      return false;
    }

    if (fusion_nodes.empty()) {
      continue;
    }
    auto first_node = fusion_nodes.at(0);
    if (first_node == nullptr) {
      continue;
    }
    // only compute slice info when the first fusion node does not carry it yet
    string fusion_op_slice_info;
    (void)ge::AttrUtils::GetStr(first_node->GetOpDesc(), FUSION_OP_SLICE_INFO, fusion_op_slice_info);
    if (fusion_op_slice_info.empty()) {
      UbPassSliceInfoManagerPtr ub_slice_info_manager_ptr;
      FE_MAKE_SHARED(ub_slice_info_manager_ptr = std::make_shared<UbPassSliceInfoManager>(), return false);
      ub_slice_info_manager_ptr->SetSliceInfoForFusionNodes(fusion_nodes);
    }

    // if nodes have cube and vector core type, do not need to fuse.
    if (CheckCubeVectorSplit(fusion_nodes)) {
      FE_LOGD("UbFusionPass[%s]: pattern=%s, headnode=%s. With cube and vector core type, do not need to fuse",
              pass_name.c_str(), pattern_name.c_str(), node_g->GetName().c_str());
      continue;
    }

    // set scope_id
    SetScopeIdAndPassName(fusion_nodes, pass_name, pattern_name);
    FE_LOGD("UbFusionPass[%s]: pattern=%s, headnode=%s.", pass_name.c_str(), pattern_name.c_str(),
            node_g->GetName().c_str());
    for (const auto &item : fusion_nodes) {
      // fusion_nodes may hold null entries (see the first_node check above),
      // so never dereference one just for logging.
      if (item == nullptr) {
        continue;
      }
      FE_LOGD("node:%s.", item->GetName().c_str());
    }
    reachability_->Update(graph, fusion_nodes);
    matched_times++;
  }

  FE_LOGD("UbFusionPass[%s]: pattern=%s, matched_times=%d", pass_name.c_str(), pattern_name.c_str(), matched_times);
  return true;
}

/*
 * @brief: allocate one scope id and set it, together with the pass name, on
 *         every non-null fusion node
 * @param [in] fusion_nodes: nodes to be fused; nothing is done when there are
 *             fewer than two of them
 * @param [in] pass_name: value written to the PASS_NAME_ATTR attr of each node
 * @param [in] pattern_name: pattern name, only used for logging
 * @return void
 */
void BufferFusionPassRunner::SetScopeIdAndPassName(const vector<ge::NodePtr> &fusion_nodes, const string &pass_name,
                                                   const string &pattern_name) {
  // size() returns size_t, so it must be printed with %zu rather than %d.
  FE_LOGD("Fusion nodes' size: %zu.", fusion_nodes.size());
  if (fusion_nodes.size() < 2) {
    return;
  }

  int64_t scope_id = scope_allocator_ptr_->AllocateScopeId();
  FE_LOGD("UBPass[pass_name=%s, pattern_name=%s]: set scope_id[%ld] for fusion_nodes.", pass_name.c_str(),
          pattern_name.c_str(), scope_id);
  for (const ge::NodePtr &node : fusion_nodes) {
    if (node == nullptr) {
      continue;
    }
    const string &name = node->GetName();
    // A failure to set either attr is not reported; only successes are logged.
    if (scope_allocator_ptr_->SetScopeAttr(node->GetOpDesc(), scope_id)) {
      FE_LOGD("Node[%s]: set scope_id[%ld] success.", name.c_str(), scope_id);
    }
    if (ge::AttrUtils::SetStr(node->GetOpDesc(), PASS_NAME_ATTR, pass_name)) {
      FE_LOGD("Node[%s]: set pass_name[%s] success.", name.c_str(), pass_name.c_str());
    }
  }
}

/*
 * @brief: check that all matched nodes carry the same _stream_label attr
 * @param [in] mapping: matched nodes, keyed by pattern desc
 * @return bool: false as soon as one node's label differs from the reference
 *         label, true otherwise
 */
bool BufferFusionPassRunner::CheckAttrMatch(BufferFusionMapping &mapping) {
  const auto matched_nodes = buffer_fusion_pass_base_ptr_->GetMatchedNodes(mapping);
  // Reference label; an empty string means "not set yet".
  // NOTE(review): a node whose attr exists but is empty therefore does not fix
  // the reference label — confirm this is intended.
  string expected_label;
  for (const auto &node : matched_nodes) {
    string node_label;
    if (!ge::AttrUtils::GetStr(node->GetOpDesc(), STREAM_LABEL, node_label)) {
      // Nodes without the attr are all compared under the same placeholder value.
      node_label = "null";
      FE_LOGI("Fusion nodes do not have _stream_label attr.");
    }
    if (expected_label.empty()) {
      expected_label = node_label;
      continue;
    }
    if (expected_label != node_label) {
      FE_LOGD("_stream_label not equal, pattern matching failed.");
      return false;
    }
  }
  return true;
}

/*
 * @brief: reset repeate_curr of every pattern desc to 0; for a desc whose
 *         multi_output_skip_status is not empty and whose status at
 *         repeate_min is not DISABLED, set the status of every index in
 *         [repeate_min, repeate_max) to AVAILABLE
 * @param [in] ops: op descs of the fusion pattern
 * @return void */
void BufferFusionPassRunner::InitRepeatCurr(const std::vector<BufferFusionOpDesc *> &ops) {
  for (BufferFusionOpDesc *op_desc : ops) {
    op_desc->repeate_curr = 0;

    auto &skip_status = op_desc->multi_output_skip_status;
    if (skip_status.empty()) {
      continue;
    }
    // A DISABLED status at repeate_min leaves the whole range untouched.
    if (skip_status[op_desc->repeate_min] == SkipStatus::DISABLED) {
      continue;
    }
    for (int64_t idx = op_desc->repeate_min; idx < op_desc->repeate_max; idx++) {
      skip_status[idx] = SkipStatus::AVAILABLE;
    }
  }
}

/*
 * @brief: check whether the fusion nodes contain both an op whose type is in
 *         cube_op_type_ and an op whose type is not, on a platform whose
 *         ai_core_spec.cube_vector_split equals 1
 * @param [in] fusion_nodes: nodes to be fused; null entries are skipped
 * @return bool: true when both kinds of op are present on such a platform;
 *         false otherwise, including when the platform info cannot be obtained
 */
bool BufferFusionPassRunner::CheckCubeVectorSplit(vector<ge::NodePtr> &fusion_nodes) {
  string soc_version = Configuration::Instance(AI_CORE_NAME).GetSocVersion();
  PlatformInfo platform_info;
  OptionalInfo opti_compilation_info;
  if (PlatformInfoManager::Instance().GetPlatformInfo(soc_version, platform_info, opti_compilation_info) != SUCCESS) {
    REPORT_FE_ERROR("[SubGraphOpt][UbFusion][ChkCubeVecSplit] Fail to get platform info by soc version[%s].",
                    soc_version.c_str());
    return false;
  }

  if (platform_info.ai_core_spec.cube_vector_split != 1) {
    return false;
  }

  bool find_cube_op = false;
  bool find_vector_op = false;
  for (const auto &node : fusion_nodes) {
    // fusion_nodes may hold null entries; they are neither cube nor vector ops.
    if (node == nullptr) {
      continue;
    }
    if (cube_op_type_.find(node->GetType()) != cube_op_type_.end()) {
      find_cube_op = true;
    } else {
      // every op type that is not listed in cube_op_type_ counts as a vector op
      find_vector_op = true;
    }
    // The answer cannot change once both kinds have been seen.
    if (find_cube_op && find_vector_op) {
      return true;
    }
  }
  return false;
}
}  // namespace fe
